
Conversation

@vasqu (Contributor) commented Oct 24, 2025

Non-vmap creation of attention masks. This works with all our base masks, and we only fall back to vmap for patterns we cannot guarantee (i.e. additional `and`/`or` masks).

Note:

  • Non-vmap works with every mask that is purely index-based (see the sketch after this list)
  • Merged the old/new SDPA mask creation under one function --> easier maintenance, imo
  • ExecuTorch no longer needs an additional masking fn
  • Lifts some restrictions on older torch versions, e.g. chunked attention with padding, packed attention masks, etc.
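
For illustration, here is a minimal sketch of what index-based (non-vmap) mask construction looks like. The helper name `build_causal_mask` and the exact shapes are assumptions for this example, not the actual code in this PR:

```python
import torch

def build_causal_mask(batch_size: int, q_len: int, kv_len: int) -> torch.Tensor:
    # Hypothetical helper: index-based construction broadcasts position index
    # tensors directly instead of evaluating a mask callable per element
    # through torch.vmap.
    q_idx = torch.arange(q_len).view(-1, 1)    # (q_len, 1)
    kv_idx = torch.arange(kv_len).view(1, -1)  # (1, kv_len)
    causal = kv_idx <= q_idx                   # (q_len, kv_len), bool
    # Expand to the (batch, num_heads, q_len, kv_len) layout SDPA expects.
    return causal[None, None, :, :].expand(batch_size, 1, q_len, kv_len)

print(build_causal_mask(2, 4, 4)[0, 0])
```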

Fixes #41639

cc @jiqing-feng @IlyasMoutawwakil

@HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@vasqu vasqu changed the title [WIP][Masking] Non-vmap default for attention masks [Attn Masks] Non-vmap default for attention masks Oct 29, 2025
@vasqu vasqu marked this pull request as ready for review October 29, 2025 11:06
```diff
     return cache


-def sdpa_mask_without_vmap(
```
@vasqu (Contributor, Author) commented:

No longer needed, as vmap was the reason we needed this workaround in the first place.

```diff
     NOTE: It is important to keep an index-based version for non-vmap expansion.
     """
-    return q_idx.new_ones((), dtype=torch.bool)
+    return q_idx >= 0
```
@vasqu (Contributor, Author) commented:

As noted above, for non-vmap we need this as an index-based version.
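
A minimal sketch of the shape difference (the index grids below are illustrative assumptions, not the PR's actual expansion code):

```python
import torch

# Broadcast-ready index grids, as a non-vmap expansion would build them.
q_idx = torch.arange(6).view(-1, 1)   # (q_len, 1)
kv_idx = torch.arange(6).view(1, -1)  # (1, kv_len)

scalar = q_idx.new_ones((), dtype=torch.bool)  # shape: torch.Size([])
index_based = q_idx >= 0                       # shape: torch.Size([6, 1])

# Both are "always True", but only the index-based form keeps a shape that
# tracks the position grid, so it broadcasts like any other mask term:
print((index_based & (kv_idx <= q_idx)).shape)  # torch.Size([6, 6])
```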

```diff
-    causal_mask |= torch.all(~causal_mask, dim=-1, keepdim=True)
-    return causal_mask
+    attention_mask = attention_mask | torch.all(~attention_mask, dim=-1, keepdim=True)
```
@vasqu (Contributor, Author) commented:
I encountered issues with the in-place version, where we'd need a clone (e.g. when using SWA). This is safer.
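
For context, a minimal repro of the aliasing problem that in-place ops hit on expanded views; this is an assumed illustration of the failure mode, not the exact SWA code path:

```python
import torch

# An expanded view shares memory across the expanded dim, so PyTorch
# rejects in-place writes into it.
mask = torch.zeros(1, 4, dtype=torch.bool).expand(2, 4)
try:
    mask |= torch.all(~mask, dim=-1, keepdim=True)
except RuntimeError as e:
    print(e)  # more than one element of the written-to tensor refers to a single memory location

# The out-of-place form allocates a fresh tensor: no clone() needed.
mask = mask | torch.all(~mask, dim=-1, keepdim=True)
print(mask.shape)  # torch.Size([2, 4])
```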


Successfully merging this pull request may close: MarianMTModel performance regression due to Bidirectional masks (#41639).
